/* Copyright 2019 Luca Fedeli
 *
 * This file is part of WarpX.
 *
 * License: BSD-3-Clause-LBNL
 */
#ifndef WARPX_breit_wheeler_engine_wrapper_h_
#define WARPX_breit_wheeler_engine_wrapper_h_

#include "BreitWheelerEngineWrapper_fwd.H"

#include "QedChiFunctions.H"
#include "QedWrapperCommons.H"
#include "Utils/WarpXConst.H"

#include <AMReX_Extension.H>
#include <AMReX_GpuContainers.H>
#include <AMReX_GpuQualifiers.H>
#include <AMReX_REAL.H>
#include <AMReX_Random.H>

#include <picsar_qed/containers/picsar_array.hpp>
#include <picsar_qed/math/cmath_overloads.hpp>
#include <picsar_qed/math/math_constants.h>
#include <picsar_qed/math/vec_functions.hpp>
#include <picsar_qed/physics/breit_wheeler/breit_wheeler_engine_core.hpp>
#include <picsar_qed/physics/breit_wheeler/breit_wheeler_engine_tables.hpp>
#include <picsar_qed/physics/gamma_functions.hpp>
#include <picsar_qed/physics/phys_constants.h>
#include <picsar_qed/physics/unit_conversion.hpp>

#include <cmath>
#include <vector>

namespace amrex { struct RandomEngine; }

// Aliases =============================

// Breit-Wheeler dN/dt lookup table:
//  - BW_dndt_table_params: parameters describing the table
//  - BW_dndt_table: the table, with data stored in an amrex::Gpu::DeviceVector
//  - BW_dndt_table_view: the view type handed to BreitWheelerEvolveOpticalDepth
using BW_dndt_table_params = picsar::multi_physics::phys::breit_wheeler::
    dndt_lookup_table_params<amrex::ParticleReal>;
using BW_dndt_table = picsar::multi_physics::phys::breit_wheeler::
    dndt_lookup_table<amrex::ParticleReal, amrex::Gpu::DeviceVector<amrex::ParticleReal>>;
using BW_dndt_table_view = BW_dndt_table::view_type;

// Breit-Wheeler pair production lookup table:
//  - BW_pair_prod_table_params: parameters describing the table
//  - BW_pair_prod_table: the table, with data stored in an amrex::Gpu::DeviceVector
//  - BW_pair_prod_table_view: the view type handed to BreitWheelerGeneratePairs
using BW_pair_prod_table_params = picsar::multi_physics::phys::breit_wheeler::
    pair_prod_lookup_table_params<amrex::ParticleReal>;
using BW_pair_prod_table = picsar::multi_physics::phys::breit_wheeler::
    pair_prod_lookup_table<amrex::ParticleReal, amrex::Gpu::DeviceVector<amrex::ParticleReal>>;
using BW_pair_prod_table_view = BW_pair_prod_table::view_type;

/**
 * Control parameters used to generate the Breit-Wheeler lookup tables
 * (passed to BreitWheelerEngine::compute_lookup_tables).
 */
struct PicsarBreitWheelerCtrl
{
    //! Parameters of the dN/dt lookup table
    BW_dndt_table_params dndt_params;
    //! Parameters of the pair production lookup table
    BW_pair_prod_table_params pair_prod_params;
};

// Functors ==================================

// These functors allow using the core elementary functions of the library.
// They are generated by a factory class (BreitWheelerEngine, see below).
// They can be included in GPU kernels.

/**
 * Stateless functor that draws the initial optical depth of a photon
 * for the Breit-Wheeler process.
 */
class BreitWheelerGetOpticalDepth
{
public:
    /**
     * Trivial constructor: initializing the optical depth requires
     * neither control parameters nor lookup tables.
     */
    BreitWheelerGetOpticalDepth () = default;

    /**
     * Draws a new optical depth. Usable on both host and device.
     *
     * @param[in] engine random number generator engine
     * @return the optical depth computed by picsar from a random number
     */
    AMREX_GPU_HOST_DEVICE
    AMREX_FORCE_INLINE
    amrex::ParticleReal operator() (amrex::RandomEngine const& engine) const noexcept
    {
        // picsar's get_optical_depth expects a random number in [0,1)
        const auto random_in_zero_one = amrex::Random(engine);
        return picsar::multi_physics::phys::breit_wheeler::get_optical_depth(
            random_in_zero_one);
    }
};
//____________________________________________

/**
 * Functor to evolve the optical depth of photons due to the
 * Breit-Wheeler process
 */
class BreitWheelerEvolveOpticalDepth
{
public:

    /**
     * Default constructor: it leaves the functor in a non-initialized state
     * (default-constructed table view, minimum chi equal to zero). A functor
     * built with the constructor below should be used for actual evolution.
     */
    BreitWheelerEvolveOpticalDepth () = default;


    /**
     * Constructor to be used to initialize the functor.
     *
     * @param[in] table_view a view of a BW_dndt_table lookup table
     * @param[in] bw_minimum_chi_phot the minimum quantum parameter to evolve the optical depth
     */
    BreitWheelerEvolveOpticalDepth (
        const BW_dndt_table_view table_view,
        const amrex::ParticleReal bw_minimum_chi_phot):
        m_table_view{table_view}, m_bw_minimum_chi_phot{bw_minimum_chi_phot}{}

    /**
     * Evolves the optical depth. It can be used on GPU.
     * If the quantum parameter of the photon is < bw_minimum_chi_phot,
     * the method returns immediately.
     * The method also returns immediately if the energy of the photon is
     * insufficient to generate a pair (gamma_photon < 2).
     *
     * @param[in] ux,uy,uz gamma*v components of the photon.
     * @param[in] ex,ey,ez electric field components (SI units)
     * @param[in] bx,by,bz magnetic field components (SI units)
     * @param[in] dt timestep (SI units)
     * @param[in,out] opt_depth optical depth of the photon.
     * @return a flag which is 1 if chi_phot was out of table
     */
    AMREX_GPU_DEVICE
    AMREX_FORCE_INLINE
    int operator()(
        const amrex::ParticleReal ux, const amrex::ParticleReal uy,
        const amrex::ParticleReal uz, const amrex::ParticleReal ex,
        const amrex::ParticleReal ey, const amrex::ParticleReal ez,
        const amrex::ParticleReal bx, const amrex::ParticleReal by,
        const amrex::ParticleReal bz, const amrex::Real dt,
        amrex::ParticleReal& opt_depth) const noexcept
    {
        namespace pxr_m = picsar::multi_physics::math;
        namespace pxr_p = picsar::multi_physics::phys;
        namespace pxr_bw = picsar::multi_physics::phys::breit_wheeler;

        constexpr amrex::ParticleReal m_e = PhysConst::m_e;

        // Convert gamma*v to momentum (SI units) by multiplying by m_e
        const auto px = m_e*ux;
        const auto py = m_e*uy;
        const auto pz = m_e*uz;

        // Optical depth is not evolved for photons having less energy than
        // what is required to generate a pair. This check comes first so that
        // chi is only computed when it can actually be needed.
        const auto gamma_photon = pxr_p::compute_gamma_photon<
                amrex::ParticleReal, pxr_p::unit_system::SI>(
                    px, py, pz);
        if (gamma_photon < pxr_m::two<amrex::ParticleReal>)
            return 0;

        // Nor is it evolved for a quantum parameter smaller than
        // m_bw_minimum_chi_phot
        const auto chi_phot = QedUtils::chi_photon(
            px, py, pz, ex, ey, ez, bx, by, bz);
        if (chi_phot < m_bw_minimum_chi_phot)
            return 0;

        // Photon energy |p|*c, only needed past the early exits
        const auto u_norm = std::sqrt(ux*ux + uy*uy + uz*uz);
        const auto energy = u_norm*m_e*PhysConst::c;

        const auto is_out = pxr_bw::evolve_optical_depth<
            amrex::ParticleReal,
            BW_dndt_table_view,
            pxr_p::unit_system::SI>(
                energy, chi_phot, dt, opt_depth, m_table_view);

        return is_out;
    }

private:
    BW_dndt_table_view m_table_view;
    // Initialized so that a default-constructed functor holds a
    // well-defined value instead of an indeterminate one
    amrex::ParticleReal m_bw_minimum_chi_phot{0};
};

/**
 * Functor to generate a pair via the
 * Breit-Wheeler process
 */
class BreitWheelerGeneratePairs
{
public:

    /**
     * Default constructor: it leaves the functor in a non-initialized state.
     */
    BreitWheelerGeneratePairs () = default;

    /**
     * Constructor stores a copy of the given view of the pair production
     * lookup table (no control parameters are needed).
     * lookup_table uses non-owning vectors under the hood. So no new data
     * allocations should be triggered on GPU
     *
     * @param[in] table_view a BW_pair_prod_table_view
     */
    BreitWheelerGeneratePairs (const BW_pair_prod_table_view table_view):
        m_table_view{table_view}{}

    /**
     * Generates pairs according to Breit Wheeler process.
     * It can be used on GPU.
     * Warning: the energy of the photon must be > 2mec^2, but it is not checked
     * in this method.
     *
     * @param[in] ux,uy,uz gamma*v components of the photon (SI units)
     * @param[in] ex,ey,ez electric field components (SI units)
     * @param[in] bx,by,bz magnetic field components (SI units)
     * @param[out] e_ux,e_uy,e_uz gamma*v components of generated electron (SI units)
     * @param[out] p_ux,p_uy,p_uz gamma*v components of generated positron (SI units)
     * @param[in] engine random number generator engine
     * @return a flag which is 1 if chi_photon was out of table
     */
    AMREX_GPU_DEVICE
    AMREX_FORCE_INLINE
    int operator()(
    const amrex::ParticleReal ux, const amrex::ParticleReal uy,
    const amrex::ParticleReal uz, const amrex::ParticleReal ex,
    const amrex::ParticleReal ey, const amrex::ParticleReal ez,
    const amrex::ParticleReal bx, const amrex::ParticleReal by,
    const amrex::ParticleReal bz, amrex::ParticleReal& e_ux,
    amrex::ParticleReal& e_uy, amrex::ParticleReal& e_uz,
    amrex::ParticleReal& p_ux, amrex::ParticleReal& p_uy,
    amrex::ParticleReal& p_uz,
    amrex::RandomEngine const& engine) const noexcept
    {
        // Only the literal operators (_prt) are needed, not all of amrex
        using namespace amrex::literals;
        namespace pxr_m = picsar::multi_physics::math;
        namespace pxr_p = picsar::multi_physics::phys;
        namespace pxr_bw = picsar::multi_physics::phys::breit_wheeler;

        // Random number in [0,1) consumed by picsar's pair generation routine
        const auto rand_zero_one_minus_epsi = amrex::Random(engine);

        constexpr amrex::ParticleReal me = PhysConst::m_e;
        constexpr amrex::ParticleReal one_over_me = 1._prt/me;

        // Particle momentum is stored as gamma * velocity.
        // Convert to m * gamma * velocity
        const auto px = ux*me;
        const auto py = uy*me;
        const auto pz = uz*me;

        const auto chi_photon = QedUtils::chi_photon(
            px, py, pz, ex, ey, ez, bx, by, bz);

        const auto momentum_photon = pxr_m::vec3<amrex::ParticleReal>{px, py, pz};
        auto momentum_ele = pxr_m::vec3<amrex::ParticleReal>();
        auto momentum_pos = pxr_m::vec3<amrex::ParticleReal>();

        const auto is_out = pxr_bw::generate_breit_wheeler_pairs<
            amrex::ParticleReal,
            BW_pair_prod_table_view,
            pxr_p::unit_system::SI>(
                chi_photon, momentum_photon,
                rand_zero_one_minus_epsi,
                m_table_view,
                momentum_ele, momentum_pos);

        // Convert the momenta of the generated particles back to gamma * velocity
        e_ux = momentum_ele[0]*one_over_me;
        e_uy = momentum_ele[1]*one_over_me;
        e_uz = momentum_ele[2]*one_over_me;
        p_ux = momentum_pos[0]*one_over_me;
        p_uy = momentum_pos[1]*one_over_me;
        p_uz = momentum_pos[2]*one_over_me;

        return is_out;
    }

private:
    BW_pair_prod_table_view m_table_view;
};

// Factory class =============================

/**
 * Wrapper for the Breit Wheeler engine of the PICSAR library
 */
class BreitWheelerEngine
{
public:
    /**
     * Constructor requires no arguments.
     */
    BreitWheelerEngine () = default;

    /**
     * Builds the functor to initialize the optical depth
     */
    BreitWheelerGetOpticalDepth build_optical_depth_functor () const;

    /**
     * Builds the functor to evolve the optical depth
     */
    BreitWheelerEvolveOpticalDepth build_evolve_functor () const;

    /**
     * Builds the functor to generate the pairs
     */
    BreitWheelerGeneratePairs build_pair_functor () const;

    /**
     * Checks if the lookup tables are properly initialized
     */
    bool are_lookup_tables_initialized () const;

    /**
     * Export lookup tables data into a raw binary Vector
     *
     * @return the data in binary format. The Vector is empty if tables were
     * not previously initialized.
     */
    std::vector<char> export_lookup_tables_data () const;

    /**
     * Init lookup tables from raw binary data.
     *
     * @param[in] raw_data a vector of char
     * @param[in] bw_minimum_chi_phot minimum chi parameter to evolve the optical depth of a photon
     * @return true if it succeeds, false if it cannot parse raw_data
     */
    bool init_lookup_tables_from_raw_data (
        const std::vector<char>& raw_data,
        const amrex::ParticleReal bw_minimum_chi_phot);

    /**
     * Init lookup tables using built-in (low resolution) tables
     *
     * @param[in] bw_minimum_chi_phot minimum chi parameter to evolve the optical depth of a photon
     */
    void init_builtin_tables(const amrex::ParticleReal bw_minimum_chi_phot);

    /**
     * Computes the lookup tables. It does nothing unless WarpX is compiled with QED_TABLE_GEN=TRUE
     *
     * @param[in] ctrl control params to generate the tables
     * @param[in] bw_minimum_chi_phot minimum chi parameter to evolve the optical depth of a photon
     */
    void compute_lookup_tables (const PicsarBreitWheelerCtrl ctrl,
        const amrex::ParticleReal bw_minimum_chi_phot);

    /**
     * gets default values for the control parameters
     *
     * @return default control params to generate the tables
     */
    PicsarBreitWheelerCtrl get_default_ctrl() const;

    /**
     * Gets the minimum chi parameter to evolve the optical depth of a photon
     *
     * @return the minimum photon chi parameter (presumably the value passed to
     * one of the table initialization functions; zero if none was called)
     */
    amrex::ParticleReal get_minimum_chi_phot() const;

private:
    bool m_lookup_tables_initialized = false;

    //Variable to store the minimum chi parameter to enable the
    //Breit-Wheeler process. Initialized (like the flag above) so that
    //reading it before any table initialization is well-defined.
    amrex::ParticleReal m_bw_minimum_chi_phot = 0;

    BW_dndt_table m_dndt_table;
    BW_pair_prod_table m_pair_prod_table;

    // Helpers presumably used by init_builtin_tables to fill each table
    // with built-in data (definitions not visible here)
    void init_builtin_dndt_table();
    void init_builtin_pair_prod_table();
};

//============================================

#endif //WARPX_breit_wheeler_engine_wrapper_H_
